"""
@Author: Fhz
@Create Date: 2024/1/18 15:13
@File: MTF-LSTM-test.py
@Description: 
@Modify Person Date: 
"""
import torch
import random
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def feature_scaling(x_seq):
    """Scale a 1-D sequence into [-1, 1] and return it with its bounds.

    Returns a tuple ``(scaled, seq_min, seq_max)``. A constant sequence
    cannot be scaled (zero range), so it is returned as-is, repeated.
    """
    seq_min = x_seq.min()
    seq_max = x_seq.max()
    if seq_min == seq_max:
        # Degenerate case: avoid division by zero, keep the constant value.
        return seq_min * np.ones(shape=x_seq.shape), seq_min, seq_max
    scaled = (2 * x_seq - (seq_max + seq_min)) / (seq_max - seq_min)
    return scaled, seq_min, seq_max


def de_feature_scaling(x_new, x_min, x_max):
    """Invert feature_scaling for the first 3 feature channels.

    Parameters:
        x_new: (N, T, F) array of scaled sequences.
        x_min, x_max: (N, F) per-sequence, per-feature scaling bounds.

    Returns an (N, T, F) array. Only channels 0-2 are de-scaled; the
    remaining channels are left as ones (see note below).
    """
    # Generalized: derive dimensions from the input instead of the
    # hard-coded (N, 80, 44) of the original.
    n_seq, n_steps, n_feat = x_new.shape
    x_ori = np.ones(shape=(n_seq, n_steps, n_feat))
    for i in range(n_seq):
        # NOTE(review): only the first 3 feature channels are de-scaled,
        # exactly as in the original code; channels 3+ stay at 1.0 in both
        # prediction and ground truth, so they cancel in the loss. Confirm
        # this is intended rather than a leftover from a 3-feature version.
        for j in range(3):
            if x_min[i, j] == x_max[i, j]:
                # Constant channel: scaling was a no-op, restore the constant.
                x_ori[i, :, j] = x_min[i, j]
            else:
                x_ori[i, :, j] = (x_new[i, :, j] * (x_max[i, j] - x_min[i, j]) + x_max[i, j] + x_min[i, j]) / 2

    return x_ori


def data_diff(data):
    """Split a sequence into its first element and its first-order diffs.

    Returns ``(first_value, diffs)`` where ``diffs`` has one fewer element
    along the differenced axis; together they fully encode the sequence.
    """
    # Renamed locals so the function name is not shadowed by a variable.
    first_value = data[0]
    increments = np.diff(data)
    return first_value, increments


def de_data_diff(data_0, data_diff):
    """Reconstruct absolute sequences from initial values and step diffs.

    Inverse of data_diff applied per feature:

    Parameters:
        data_0: (N, F) initial values (time step 0).
        data_diff: (N, T-1, F) per-step increments.

    Returns an (N, T, F) array of reconstructed sequences.
    """
    # Generalized: shapes derived from data_diff instead of the original
    # hard-coded (N, 80, 44) / range(79).
    n_seq, n_diff, n_feat = data_diff.shape
    data = np.empty(shape=(n_seq, n_diff + 1, n_feat))
    data[:, 0, :] = data_0
    # Cumulative sum along the time axis restores absolute values; this is
    # exactly the original step-by-step loop, vectorized.
    data[:, 1:, :] = data_0[:, None, :] + np.cumsum(data_diff, axis=1)

    return data


def dataNormal(seq):
    """Normalize raw sequences: per-feature scaling to [-1, 1], then diff.

    Parameters:
        seq: (N, T, F) array of raw sequences.

    Returns:
        seq_norm: (N, T-1, F) differenced, scaled sequences.
        seq_norm_feature: (N, 3, F) reconstruction metadata per feature:
            row 0 = min, row 1 = max, row 2 = first scaled value.
    """
    # Generalized: dimensions derived from the input instead of the
    # hard-coded 79 / 44 of the original.
    n_seq, n_steps, n_feat = seq.shape
    seq_norm = np.zeros(shape=(n_seq, n_steps - 1, n_feat))
    seq_norm_feature = np.zeros(shape=(n_seq, 3, n_feat))

    for i in range(n_seq):
        for j in range(n_feat):
            # Scale each (sequence, feature) column independently, then
            # difference so the model predicts increments.
            scaled, s_min, s_max = feature_scaling(seq[i, :, j])
            s_0, s_diff = data_diff(scaled)
            seq_norm[i, :, j] = s_diff

            # Keep everything needed to invert the transform later
            # (de_data_diff followed by de_feature_scaling).
            seq_norm_feature[i, 0, j] = s_min
            seq_norm_feature[i, 1, j] = s_max
            seq_norm_feature[i, 2, j] = s_0

    return seq_norm, seq_norm_feature


def get_train_dataset(train_data, batch_size):
    """Build a shuffled training DataLoader from normalized sequences.

    Each sequence is split at step 49: the first 49 steps form the model
    input, the remaining steps the prediction target (2 s horizon via a
    sliding window, per the original author's note).
    """
    inputs = torch.from_numpy(train_data[:, :49, :].copy()).to(torch.float32)
    targets = torch.from_numpy(train_data[:, 49:, :].copy()).to(torch.float32)

    dataset = TensorDataset(inputs, targets)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


def get_test_dataset(test_data, test_seq_NF, batch_size):
    """Build a DataLoader pairing normalized sequences with their scaling metadata.

    Unlike the training loader, sequences are kept whole; the second tensor
    carries the per-feature (min, max, first-value) reconstruction info.
    """
    sequences = torch.from_numpy(test_data.copy()).to(torch.float32)
    norm_features = torch.from_numpy(test_seq_NF.copy()).to(torch.float32)

    dataset = TensorDataset(sequences, norm_features)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


def LoadData(num):
    """Load the validation feature/label arrays for dataset split ``num``.

    Reads the pre-processed ``.npy`` files produced by the data pipeline
    and returns them as ``(features, labels)`` numpy arrays.
    """
    features = np.load(file="../data_process/processed_dataset/X_valid_{}.npy".format(num))
    labels = np.load(file="../data_process/processed_dataset/y_valid_{}.npy".format(num))
    return features, labels


class lstm_encoder(nn.Module):
    """Multi-layer LSTM encoder (batch-first) with inter-layer dropout.

    Encodes an observed sequence and exposes the final hidden/cell state
    so a decoder can continue from it.
    """

    def __init__(self, input_size, hidden_size, num_layers=4):
        super(lstm_encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=0.2,
                            batch_first=True)

    def forward(self, input):
        """Encode ``input`` of shape (batch, seq, feat).

        Returns ``(outputs, (h_n, c_n))``; the state is also stashed on
        ``self.hidden`` as in the original implementation.
        """
        lstm_out, self.hidden = self.lstm(input)
        return lstm_out, self.hidden


class lstm_decoder(nn.Module):
    """Single-step LSTM decoder with a linear head back to the feature space.

    Consumes one time step at a time plus the previous hidden state, so the
    caller controls autoregressive vs teacher-forced decoding.
    """

    def __init__(self, input_size, hidden_size, num_layers=4):
        super(lstm_decoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=0.2,
                            batch_first=True)
        # Project the hidden state back to the input feature dimension.
        self.fc = nn.Linear(hidden_size, input_size)

    def forward(self, input, encoder_hidden_states):
        """Run one decode step.

        ``input`` is (batch, feat); a length-1 time axis is added for the
        LSTM and removed before the linear projection. Returns the
        (batch, feat) prediction and the updated hidden state.
        """
        step_in = input.unsqueeze(1)
        lstm_out, self.hidden = self.lstm(step_in, encoder_hidden_states)
        output = self.fc(lstm_out.squeeze(1))
        return output, self.hidden


class MyLstm(nn.Module):
    """Seq2seq LSTM: encoder + step-wise decoder over ``target_len`` steps.

    Supports three decoding modes: recursive (feed back own predictions),
    teacher forcing (always feed ground truth), and mixed teacher forcing
    (ground truth with probability ``TR``).
    """

    def __init__(self, input_size=44, hidden_size=256, batch_size=1000, target_len=30, TR=0.1):
        super(MyLstm, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size  # kept for interface compatibility; not used here
        self.target_len = target_len
        self.TR = TR  # teacher-forcing ratio for "mixed_teacher_forcing"

        self.encoder = lstm_encoder(input_size=self.input_size, hidden_size=self.hidden_size)
        self.decoder = lstm_decoder(input_size=self.input_size, hidden_size=self.hidden_size)

    def forward(self, input, target, training_prediction="recursive"):
        """Predict ``target_len`` future steps from the observed ``input``.

        Parameters:
            input: (batch, obs_len, feat) observed history.
            target: (batch, >= target_len, feat) ground truth; only read in
                the teacher-forcing modes.
            training_prediction: "recursive", "teacher_forcing", or
                "mixed_teacher_forcing".

        Returns a (batch, target_len, feat) tensor of predictions.
        """
        encoder_output, encoder_hidden = self.encoder(input)
        decoder_input = input[:, -1, :]  # seed decoding with the last observed step
        decoder_hidden = encoder_hidden

        # BUG FIX: allocate the output buffer on the input's device. The
        # original used a CPU tensor, which fails when the model runs on
        # CUDA (the script moves the model to cuda:1 when available) and
        # decoder outputs are assigned into it.
        outputs = torch.zeros(input.shape[0], self.target_len, input.shape[2],
                              device=input.device)

        if training_prediction == "recursive":
            # Autoregressive: each prediction becomes the next input.
            for t in range(self.target_len):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                outputs[:, t, :] = decoder_output
                decoder_input = decoder_output

        if training_prediction == "teacher_forcing":
            # Always condition the next step on the ground truth.
            for t in range(self.target_len):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                outputs[:, t, :] = decoder_output
                decoder_input = target[:, t, :]

        if training_prediction == "mixed_teacher_forcing":
            # Per-step coin flip: ground truth with probability TR,
            # otherwise feed back the model's own prediction.
            teacher_forcing_ratio = self.TR
            for t in range(self.target_len):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                outputs[:, t, :] = decoder_output

                if random.random() < teacher_forcing_ratio:
                    decoder_input = target[:, t, :]
                else:
                    decoder_input = decoder_output

        return outputs


class RMSELoss(torch.nn.Module):
    """Mean of the per-sample Euclidean error norm along the last dimension.

    For each sample, computes sqrt(sum((x - y)^2)) over the final axis,
    then averages over all remaining elements.
    """

    def __init__(self):
        super(RMSELoss, self).__init__()

    def forward(self, x, y):
        squared_err = nn.MSELoss(reduction="none")(x, y)
        per_sample = torch.sqrt(squared_err.sum(dim=-1))
        return per_sample.mean()


class RMSEsum(torch.nn.Module):
    """Per-sample Euclidean error norm along the last dimension, unreduced.

    Same per-sample quantity as RMSELoss, but returned as a tensor instead
    of being averaged, so the caller can threshold individual samples.
    """

    def __init__(self):
        super(RMSEsum, self).__init__()

    def forward(self, x, y):
        squared_err = nn.MSELoss(reduction="none")(x, y)
        return squared_err.sum(dim=-1).sqrt()


if __name__ == '__main__':

    # Evaluation grid: 10 dataset splits x 9 teacher-forcing ratios x 3 runs,
    # with 5 metrics per cell (see the result44 assignments below).
    result44 = np.zeros(shape=(10, 9, 3, 5))

    for dataset_num in range(10):
        # y_valid is loaded but never used in this evaluation loop.
        seq_valid, y_valid = LoadData(dataset_num)

        # Per-feature scaling + first-order diff; the feature tensor keeps
        # [min, max, first value] so sequences can be reconstructed later.
        x_norm_valid, x_norm_valid_feature = dataNormal(seq_valid)

        batch_size = 1024
        epochs = 1
        learning_rate = 0.001

        valid_loader = get_test_dataset(x_norm_valid, x_norm_valid_feature, batch_size)

        for tr in range(9):
            # Teacher-forcing ratio in {0.1, ..., 0.9}; only selects which
            # checkpoint to load — inference below uses "recursive" decoding.
            TR = (tr + 1) / 10

            for times in range(3):
                # Checkpoint naming: dataset split, ratio index, repetition.
                model_name = "TP_models/MTF-LSTM_D{}_R{}_T{}.pkl".format(dataset_num, tr + 1, times + 1)
                print(model_name)

                model = MyLstm(TR=TR)
                rmse_loss = RMSELoss()
                rmse_sum = RMSEsum()
                # NOTE(review): the optimizer is never stepped in this script
                # (evaluation only); it looks like leftover training code.
                optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.0001)

                device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
                model.to(device)
                print(device)

                # loss_train / loss_test / loss_test_history are never used
                # below — presumably leftovers from the training script.
                loss_train = []
                loss_test = []
                loss_test_history = 1000

                # MR1 / MR2: counts of samples whose final-step error norm is
                # above / below the 2.0 threshold, for the hit-rate metric.
                MR1 = 0
                MR2 = 0
                for i in range(epochs):

                    # With epochs == 1 this branch runs exactly once.
                    if i == epochs - 1:
                        # valid loader
                        loss_tmp = []
                        loss_1s_tmp = []
                        loss_2s_tmp = []
                        loss_3s_tmp = []

                        for batch_idx, (x_seq, x_seq_NF) in enumerate(valid_loader):
                            x_data = x_seq.to(device)
                            x_data = x_data.to(torch.float32)

                            # Normalization metadata: row 0 = min, row 1 = max,
                            # row 2 = first scaled value (see dataNormal).
                            x_seq_NF = x_seq_NF.to(device)
                            x_seq_NF = x_seq_NF.to(torch.float32)

                            # muti-steps prediction
                            x_data_ori = x_data.clone()
                            # First 49 (diffed) steps are the observed history;
                            # the remainder is the prediction horizon.
                            x_tmp = x_data[:, :49, :]
                            y_tmp = x_data[:, 49:, :]

                            # NOTE(review): reloading the checkpoint on every
                            # batch is redundant — it could be hoisted above
                            # the loader loop.
                            model.load_state_dict(torch.load(model_name))

                            with torch.no_grad():
                                pred = model(x_tmp, y_tmp, training_prediction="recursive")
                                pred = pred.to(device)

                            # Splice predictions over the horizon, then undo
                            # the diff + scaling to compare in original units.
                            x_data[:, 49:, :] = pred
                            x_seq_NF = x_seq_NF.cpu().numpy()

                            pred_seq_np = x_data.cpu().numpy()
                            pred_seq_dediff = de_data_diff(x_seq_NF[:, 2, :], pred_seq_np)
                            pred_seq_ori = de_feature_scaling(pred_seq_dediff, x_seq_NF[:, 0, :], x_seq_NF[:, 1, :])

                            # Same reconstruction for the untouched ground truth.
                            x_data_ori_np = x_data_ori.cpu().numpy()
                            x_data_ori_dediff = de_data_diff(x_seq_NF[:, 2, :], x_data_ori_np)
                            x_data_oo = de_feature_scaling(x_data_ori_dediff, x_seq_NF[:, 0, :], x_seq_NF[:, 1, :])

                            pred_seq_ori_torch = torch.from_numpy(pred_seq_ori)
                            x_data_oo_torch = torch.from_numpy(x_data_oo)

                            pred_seq_ori_torch = pred_seq_ori_torch.to(torch.float32)
                            x_data_oo_torch = x_data_oo_torch.to(torch.float32)

                            # Get RMSE_loss of each prediction step
                            # Steps 59/69/79 — presumably 1 s / 2 s / 3 s ahead
                            # at 10 frames per second; TODO confirm frame rate.
                            loss_1s = rmse_loss(x_data_oo_torch[:, 59, :], pred_seq_ori_torch[:, 59, :])
                            loss_2s = rmse_loss(x_data_oo_torch[:, 69, :], pred_seq_ori_torch[:, 69, :])
                            loss_3s = rmse_loss(x_data_oo_torch[:, 79, :], pred_seq_ori_torch[:, 79, :])

                            loss = rmse_loss(x_data_oo_torch[:, 50:, :], pred_seq_ori_torch[:, 50:, :])

                            loss_1s_tmp.append(loss_1s)
                            loss_2s_tmp.append(loss_2s)
                            loss_3s_tmp.append(loss_3s)

                            loss_tmp.append(loss.item())

                            # Final-step per-sample error norms for the hit rate.
                            rmse_5s_ = rmse_sum(x_data_oo_torch[:, 79, :], pred_seq_ori_torch[:, 79, :])
                            rmse_5s = np.array(rmse_5s_)

                            # NOTE(review): samples with error exactly 2.0 are
                            # counted in both MR1 and MR2 (>= and <= overlap);
                            # confirm the intended boundary handling.
                            MR1 = MR1 + np.sum(rmse_5s >= 2)
                            MR2 = MR2 + np.sum(rmse_5s <= 2)

                        loss_1s_tmp_np = np.array(loss_1s_tmp)
                        loss_2s_tmp_np = np.array(loss_2s_tmp)
                        loss_3s_tmp_np = np.array(loss_3s_tmp)

                        loss_1s_mean = loss_1s_tmp_np.mean()
                        loss_2s_mean = loss_2s_tmp_np.mean()
                        loss_3s_mean = loss_3s_tmp_np.mean()

                        loss_tmp_np = np.array(loss_tmp)
                        loss_mean = loss_tmp_np.mean()

                        # Metrics: [1s RMSE, 2s RMSE, 3s RMSE, horizon RMSE,
                        # share of samples within the 2.0 error threshold].
                        result44[dataset_num, tr, times, 0] = loss_1s_mean
                        result44[dataset_num, tr, times, 1] = loss_2s_mean
                        result44[dataset_num, tr, times, 2] = loss_3s_mean
                        result44[dataset_num, tr, times, 3] = loss_mean
                        result44[dataset_num, tr, times, 4] = MR2 / (MR1 + MR2)

    np.save(file="LSTM_result.npy", arr=result44)

